import torch
from train import *
import numpy as np
import pandas as pd

# Load the trained model.
def load_model(path='model.pkl'):
    """Build the text classifier and load its trained weights.

    The constructor arguments are hard-coded and must match the configuration
    used when the checkpoint was saved (presumably the training setup in
    ``train`` -- TODO confirm they stay in sync).

    Args:
        path: File containing the saved ``state_dict``. Defaults to
            ``'model.pkl'`` (the previously hard-coded location).

    Returns:
        The model with weights loaded onto the CPU, switched to evaluation mode.
    """
    model = TextClassification(411, 128, 256, 2, False, 0.2, 4)
    # map_location='cpu' lets a checkpoint written during a GPU run load on a
    # CPU-only machine; predict() builds CPU tensors, so the model must be on
    # the CPU regardless.
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    # Inference mode: disables training-only behaviour (e.g. dropout, if the
    # model uses any) so repeated predictions on the same input agree.
    model.eval()
    return model

# Prediction.
def predict(string, model=None):
    """Classify a single piece of text and return the predicted class index.

    Args:
        string: Raw input text. It is converted to token indices with
            ``seq2index`` and padded with ``padding_seq`` (helpers from
            ``train``).
        model: Optional already-loaded model. When ``None`` (the default) a
            fresh model is loaded via ``load_model()`` on each call; pass one
            in to reuse it across many predictions. The model is switched to
            evaluation mode before use.

    Returns:
        A 0-d NumPy integer array holding the argmax over the first (and only)
        row of the model output.
    """
    if model is None:
        model = load_model()
    # Ensure training-only layers are disabled even if the caller supplied a
    # model that is still in training mode.
    model.eval()
    # Add a leading batch axis of size 1 before padding.
    index = np.array(seq2index(string))[np.newaxis, :]
    x = torch.from_numpy(padding_seq(index))
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        y = model(x)
    return torch.argmax(y[0]).numpy()


if __name__ == '__main__':
    # Smoke test: classify one sample question and print its predicted label index.
    sample_text = '华润万家的企业资源项编号是什么'
    print(predict(sample_text))

